// Copyright 2017 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// embed generates a .go file from the contents of a list of data files. It is
// invoked by go_embed_data as an action.
package main

import (
	"archive/tar"
	"archive/zip"
	"bufio"
	"errors"
	"flag"
	"fmt"
	"io"
	"log"
	"os"
	"path"
	"path/filepath"
	"strconv"
	"strings"
	"text/template"
	"unicode/utf8"
)

// Templates used to render the generated file.
//
// headerTpl renders the top of every output file: a generated-code banner
// naming the rule label (.Label) followed by the package clause (.Package).
//
// multiFooterTpl renders, in -multi mode, the map variable named .Var. Its
// element type is .Type, and it has one entry per name in .FoundSources: the
// key is .Key applied to that name, and the value is the per-file variable
// .Var_<index> written before the footer.
var (
	headerTpl = template.Must(template.New("embed").Parse(`// Generated by go_embed_data for {{.Label}}. DO NOT EDIT.

package {{.Package}}

`))

	multiFooterTpl = template.Must(template.New("embed").Parse(`
var {{.Var}} = map[string]{{.Type}}{
{{- range $i, $f := .FoundSources}}
	{{$.Key $f}}: {{$.Var}}_{{$i}},
{{- end}}
}

`))
)

// main runs the embed tool on the process arguments. On failure it logs the
// error, prefixed with "embed: " and without a timestamp, and exits with a
// non-zero status via log.Fatal.
func main() {
	log.SetFlags(0) // build-action output does not need timestamps
	log.SetPrefix("embed: ")

	err := run(os.Args)
	if err == nil {
		return
	}
	log.Fatal(err)
}

// configuration holds the options for one invocation of the embed tool, as
// parsed from the command line by newConfiguration.
//
// Label, Package and Var come from -label, -package and -var; Multi from
// -multi; out and workspace from -out and -workspace; flatten, unpack and
// strData from -flatten, -unpack and -string. sources holds the positional
// arguments. FoundSources is filled in while embedding with the name of each
// embedded file, in the order the files were written.
//
// The exported fields and the Type and Key methods are read by the
// text/template templates (headerTpl, multiFooterTpl), which can only access
// exported names.
type configuration struct {
	Label, Package, Var      string
	Multi                    bool
	sources                  []string
	FoundSources             []string
	out, workspace           string
	flatten, unpack, strData bool
}

// Type returns the Go element type of the embedded values: "string" when
// -string was given and "[]byte" otherwise.
func (c *configuration) Type() string {
	if c.strData {
		return "string"
	}
	return "[]byte"
}

// Key returns the map key for the embedded file filename, already quoted as a
// Go string literal.
//
// With -flatten the key is the file's base name. Otherwise it is filename with
// any leading "external/<workspace>/" prefix removed, converted with
// filepath.FromSlash (a no-op on hosts whose separator is '/').
func (c *configuration) Key(filename string) string {
	if c.flatten {
		return strconv.Quote(path.Base(filename))
	}
	workspacePrefix := "external/" + c.workspace + "/"
	return strconv.Quote(filepath.FromSlash(strings.TrimPrefix(filename, workspacePrefix)))
}

// run generates the output file described by the command line args (main
// passes os.Args). It writes headerTpl, then either the single embedded value
// or, with -multi, the per-file variables and the map that references them.
//
// Output goes through a bufio.Writer, so a failed write may only surface when
// the buffer is flushed or the file is closed. Both errors are returned here
// instead of being discarded in a defer. Otherwise a failure such as a full
// disk could leave a truncated .go file while the tool exits successfully.
func run(args []string) (err error) {
	c, err := newConfiguration(args)
	if err != nil {
		return err
	}

	f, err := os.Create(c.out)
	if err != nil {
		return err
	}
	defer func() {
		// Report the Close error only if nothing failed earlier, so the
		// first (root-cause) error is the one returned.
		if closeErr := f.Close(); err == nil {
			err = closeErr
		}
	}()
	w := bufio.NewWriter(f)

	if err := headerTpl.Execute(w, c); err != nil {
		return err
	}
	if c.Multi {
		err = embedMultipleFiles(c, w)
	} else {
		err = embedSingleFile(c, w)
	}
	if err != nil {
		return err
	}
	// Flush explicitly so buffered-write failures are reported.
	return w.Flush()
}

func newConfiguration(args []string) (*configuration, error) {
	var c configuration
	flags := flag.NewFlagSet("embed", flag.ExitOnError)
	flags.StringVar(&c.Label, "label", "", "Label of the rule being executed (required)")
	flags.StringVar(&c.Package, "package", "", "Go package name (required)")
	flags.StringVar(&c.Var, "var", "", "Variable name (required)")
	flags.BoolVar(&c.Multi, "multi", false, "Whether the variable is a map or a single value")
	flags.StringVar(&c.out, "out", "", "Go file to generate (required)")
	flags.StringVar(&c.workspace, "workspace", "", "Name of the workspace (required)")
	flags.BoolVar(&c.flatten, "flatten", false, "Whether to access files by base name")
	flags.BoolVar(&c.strData, "string", false, "Whether to store contents as strings")
	flags.BoolVar(&c.unpack, "unpack", false, "Whether to treat files as archives to unpack.")
	flags.Parse(args[1:])
	if c.Label == "" {
		return nil, errors.New("error: -label option not provided")
	}
	if c.Package == "" {
		return nil, errors.New("error: -package option not provided")
	}
	if c.Var == "" {
		return nil, errors.New("error: -var option not provided")
	}
	if c.out == "" {
		return nil, errors.New("error: -out option not provided")
	}
	if c.workspace == "" {
		return nil, errors.New("error: -workspace option not provided")
	}
	c.sources = flags.Args()
	if !c.Multi && len(c.sources) != 1 {
		return nil, fmt.Errorf("error: -multi flag not given, so want exactly one source; got %d", len(c.sources))
	}
	if c.unpack {
		if !c.Multi {
			return nil, errors.New("error: -multi flag is required for -unpack mode.")
		}
		for _, src := range c.sources {
			if ext := filepath.Ext(src); ext != ".zip" && ext != ".tar" {
				return nil, fmt.Errorf("error: -unpack flag expects .zip or .tar extension (got %q)", ext)
			}
		}
	}
	return &c, nil
}

// embedSingleFile writes one declaration, "var <c.Var> = ...", whose value is
// the escaped contents of c.sources[0]. The value is an interpreted string
// literal with -string, and a []byte conversion of one otherwise.
func embedSingleFile(c *configuration, w io.Writer) error {
	prefix, suffix := "[]byte(\"", "\")\n"
	if c.strData {
		prefix, suffix = "\"", "\"\n"
	}

	if _, err := fmt.Fprintf(w, "var %s = %s", c.Var, prefix); err != nil {
		return err
	}
	if err := embedFileContents(w, c.sources[0]); err != nil {
		return err
	}
	if _, err := io.WriteString(w, suffix); err != nil {
		return err
	}
	return nil
}

// embedMultipleFiles writes the -multi form of the output. Inside a single
// "var ( ... )" group it declares one variable per embedded file, named
// <c.Var>_<index>. It then renders multiFooterTpl, which declares the map
// c.Var referencing those variables. findSources supplies each file's
// contents and records its name in c.FoundSources, which the footer reads.
func embedMultipleFiles(c *configuration, w io.Writer) error {
	prefix, suffix := "[]byte(\"", "\")\n"
	if c.strData {
		prefix, suffix = "\"", "\"\n"
	}

	// writeEntry emits "\t<Var>_<i> = <escaped contents>" for one file.
	writeEntry := func(i int, contents io.Reader) error {
		if _, err := fmt.Fprintf(w, "\t%s_%d = %s", c.Var, i, prefix); err != nil {
			return err
		}
		if _, err := io.Copy(&escapeWriter{w: w}, contents); err != nil {
			return err
		}
		_, err := io.WriteString(w, suffix)
		return err
	}

	if _, err := io.WriteString(w, "var (\n"); err != nil {
		return err
	}
	if err := findSources(c, writeEntry); err != nil {
		return err
	}
	if _, err := io.WriteString(w, ")\n"); err != nil {
		return err
	}
	return multiFooterTpl.Execute(w, c)
}

func findSources(c *configuration, cb func(i int, f io.Reader) error) error {
	if c.unpack {
		for _, filename := range c.sources {
			ext := filepath.Ext(filename)
			if ext == ".zip" {
				if err := findZipSources(c, filename, cb); err != nil {
					return err
				}
			} else if ext == ".tar" {
				if err := findTarSources(c, filename, cb); err != nil {
					return err
				}
			} else {
				panic("unknown archive extension: " + ext)
			}
		}
		return nil
	}
	for _, filename := range c.sources {
		f, err := os.Open(filename)
		if err != nil {
			return err
		}
		err = cb(len(c.FoundSources), bufio.NewReader(f))
		f.Close()
		if err != nil {
			return err
		}
		c.FoundSources = append(c.FoundSources, filename)
	}
	return nil
}

// findZipSources calls cb for every regular file in the zip archive named
// filename, in the order zip.Reader lists them. Each entry's name is appended
// to c.FoundSources after cb succeeds. The index passed to cb is
// len(c.FoundSources) at the time of the call, i.e. the position the entry's
// name will take.
//
// Directory entries (names ending in "/") and other non-regular entries such
// as symlinks are skipped, matching findTarSources, which only embeds
// tar.TypeReg entries. Without this check, every directory in the archive
// would be embedded as an empty value under a key ending in "/".
func findZipSources(c *configuration, filename string, cb func(i int, f io.Reader) error) error {
	r, err := zip.OpenReader(filename)
	if err != nil {
		return err
	}
	defer r.Close()
	for _, file := range r.File {
		// FileHeader.Mode reports ModeDir for names ending in "/" and
		// symlink/device bits for Unix-created archives.
		if !file.Mode().IsRegular() {
			continue
		}
		f, err := file.Open()
		if err != nil {
			return err
		}
		err = cb(len(c.FoundSources), f)
		f.Close()
		if err != nil {
			return err
		}
		c.FoundSources = append(c.FoundSources, file.Name)
	}
	return nil
}

// findTarSources calls cb for each entry of the tar archive named filename
// whose type flag is tar.TypeReg, in archive order. Entries of any other type
// are skipped. Each embedded entry's name is appended to c.FoundSources after
// cb succeeds, and the index passed to cb is the position that name will take.
func findTarSources(c *configuration, filename string, cb func(i int, f io.Reader) error) error {
	archive, err := os.Open(filename)
	if err != nil {
		return err
	}
	defer archive.Close()

	tr := tar.NewReader(bufio.NewReader(archive))
	for {
		hdr, err := tr.Next()
		switch {
		case err == io.EOF:
			return nil
		case err != nil:
			return err
		case hdr.Typeflag != tar.TypeReg:
			continue
		}

		body := &io.LimitedReader{R: tr, N: hdr.Size}
		if err := cb(len(c.FoundSources), body); err != nil {
			return err
		}
		c.FoundSources = append(c.FoundSources, hdr.Name)
	}
}

// embedFileContents copies the contents of the file named filename to w,
// passing them through escapeWriter so the written text can sit between the
// double quotes of a Go interpreted string literal.
func embedFileContents(w io.Writer, filename string) error {
	src, err := os.Open(filename)
	if err != nil {
		return err
	}
	defer src.Close()

	escaped := &escapeWriter{w: w}
	if _, err := io.Copy(escaped, bufio.NewReader(src)); err != nil {
		return err
	}
	return nil
}

// escapeWriter is an io.Writer that turns every byte written to it into text
// that is valid between the double quotes of a Go interpreted string literal
// and decodes back to the same bytes. It writes that text to w.
type escapeWriter struct {
	w io.Writer
}

// Write escapes data and forwards it to the wrapped writer, issuing one
// underlying write per input unit (one valid UTF-8 sequence or one escaped
// byte). Escaping rules:
//   - backslash, double quote and newline become \\, \" and \n;
//   - NUL becomes \x00;
//   - a valid UTF-8 sequence other than U+FEFF is copied verbatim;
//   - any other byte (invalid UTF-8, or the first byte of an encoded U+FEFF
//     or U+FFFD) becomes \xNN, and decoding resumes at the next byte.
//
// Each call is processed on its own, so a multi-byte sequence split across
// two calls is escaped byte by byte. The literal still decodes to the same
// bytes. On success Write returns len(data). If the wrapped writer fails,
// Write stops and returns that error together with the number of input bytes
// consumed so far, including those of the failed write.
func (ew *escapeWriter) Write(data []byte) (int, error) {
	// https://golang.org/ref/spec#Source_code_representation: "Implementation
	// restriction: […] A byte order mark may be disallowed anywhere else in
	// the source."
	const byteOrderMark = '\uFEFF'

	total := len(data)
	var err error
	for err == nil && len(data) > 0 {
		c := data[0]
		consumed := 1

		// https://golang.org/ref/spec#String_literals: "Within the quotes, any
		// character may appear except newline and unescaped double quote. The
		// text between the quotes forms the value of the literal, with backslash
		// escapes interpreted as they are in rune literals […]."
		switch {
		case c == '\\':
			_, err = io.WriteString(ew.w, `\\`)
		case c == '"':
			_, err = io.WriteString(ew.w, `\"`)
		case c == '\n':
			_, err = io.WriteString(ew.w, `\n`)
		case c == '\x00':
			// https://golang.org/ref/spec#Source_code_representation: "Implementation
			// restriction: For compatibility with other tools, a compiler may
			// disallow the NUL character (U+0000) in the source text."
			_, err = io.WriteString(ew.w, `\x00`)
		default:
			r, size := utf8.DecodeRune(data)
			if r == utf8.RuneError || r == byteOrderMark {
				// Escape just this byte; the rest is re-examined next iteration.
				_, err = fmt.Fprintf(ew.w, `\x%02x`, c)
			} else {
				_, err = ew.w.Write(data[:size])
				consumed = size
			}
		}
		data = data[consumed:]
	}
	return total - len(data), err
}
